package com.kool.kmqtt.server.processer;

import com.kool.kmqtt.server.PacketSender;
import com.kool.kmqtt.server.ServerConfig;
import com.kool.kmqtt.server.constant.PacketTypeEnum;
import com.kool.kmqtt.server.encoder.SubackPacketEncoder;
import com.kool.kmqtt.server.exception.ErrorCode;
import com.kool.kmqtt.server.exception.ProtocolException;
import com.kool.kmqtt.server.packet.*;
import com.kool.kmqtt.server.parser.PacketParser;
import com.kool.kmqtt.server.repository.subscription.SubscribeInfo;
import io.netty.channel.ChannelHandlerContext;

import java.util.ArrayList;
import java.util.List;

/**
 * SUBSCRIBE报文处理器
 */
public class SubscribePacketProcessor extends PacketProcessor {
    public SubscribePacketProcessor(ChannelHandlerContext ctx, PacketParser packetParser) {
        super(ctx, packetParser);
    }

    @Override
    protected void validate(Packet packet) {
        //SUBSCRIBE控制报固定报头的第3,2,1,0位是保留位，必须分别设置为0,0,1,0
        if (packet.getFixedHeader().getFlags() != 2) {
            throw new ProtocolException(ErrorCode.SUBSCRIBE_FLAGS_ERROR);
        }

        //有效载荷必须包含至少一对主题过滤器 和 QoS等级字段组合
        SubscribePayload payload = (SubscribePayload) packet.getPayload();
        List<SubscribeInfo> subscribeInfos = payload.getSubscribeInfos();
        if (subscribeInfos == null || subscribeInfos.size() == 0) {
            throw new ProtocolException(ErrorCode.SUBSCRIBE_INFO_NULL);
        }
        //QoS必须等于0,1或2
        for (SubscribeInfo subscribeInfo : subscribeInfos) {
            if (subscribeInfo.getQos() > 2) {
                throw new ProtocolException(ErrorCode.SUBSCRIBE_QOS_ERROR);
            }
        }
    }

    @Override
    protected void processPacket(Packet packet) {
        SubscribePayload payload = (SubscribePayload) packet.getPayload();
        List<SubscribeInfo> subscribeInfos = payload.getSubscribeInfos();
        //客户端id
        String clientId = sessionContext.getClientId();

        List<Integer> codes = new ArrayList<>();
        for (SubscribeInfo subscribeInfo : subscribeInfos) {
            int qos = subscribeInfo.getQos() < ServerConfig.getInstance().getQos() ? subscribeInfo.getQos() : ServerConfig.getInstance().getQos();
            //保存主题过滤器-客户端订阅信息
            repository.saveSubscribeInfo(clientId, subscribeInfo);

            codes.add(qos);

            //查询主题过滤器匹配的保留消息
            String topicFilter = subscribeInfo.getTopicFilter();
            List<Packet> retainPackets = repository.getRetainPacket(topicFilter);
            if (retainPackets != null) {
                for (Packet retainPacket : retainPackets) {
                    retainPacket.getFixedHeader().setQoS(qos);
                    //如果消息是作为客户端一个新订阅的结果发送，它必须将报文的保留标志设为1
                    retainPacket.getFixedHeader().setRetain(true);
                    //发送保留消息
                    PublishUtil.sendPublish(retainPacket, qos, sessionContext);
                }
            }
        }

        FixedHeader fixedHeader = new FixedHeader();
        fixedHeader.setPacketType(PacketTypeEnum.SUBACK.getCode());
        fixedHeader.setFlags(0);
        fixedHeader.setRemainingLength(2 + codes.size());

        SubackVariableHeader subackVariableHeader = new SubackVariableHeader();
        subackVariableHeader.setPacketId(packet.getPacketId());

        SubackPayload subackPayload = new SubackPayload();
        subackPayload.setCodes(codes);

        Packet suback = new Packet();
        suback.setFixedHeader(fixedHeader);
        suback.setVariableHeader(subackVariableHeader);
        suback.setPayload(subackPayload);

        //发送SUBACK
        PacketSender packetSender = new PacketSender(sessionContext, new SubackPacketEncoder());
        packetSender.send(suback);
    }

}
